Align Int4Tensor implementation details with the design of Float8Tensor #2687


Merged
1 commit merged into main on Aug 12, 2025

Conversation

jerryzh168
Contributor

@jerryzh168 jerryzh168 commented Aug 5, 2025

Stacked PRs:


Align Int4Tensor implementation details with the design of Float8Tensor

Summary:
Int4Tensor is the non-preshuffled version of the int4 quantized Tensor; its data has shape [N, K/2], and its scale/zero_point have shape [K/group_size, N].
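For intuition, here is a minimal plain-Python sketch (a hypothetical helper, not torchao or fbgemm code, and ignoring sign handling and the actual nibble order fbgemm uses) of why packed int4 data for an [N, K] weight has shape [N, K/2]:

```python
def pack_int4_rows(rows):
    """Pack rows of 4-bit values (0..15) into bytes, two nibbles per byte.

    A row of K int4 values becomes K/2 bytes, so an [N, K] int4 matrix
    packs into an [N, K/2] byte matrix.
    """
    packed = []
    for row in rows:
        assert len(row) % 2 == 0, "K must be even to pack nibble pairs"
        packed.append([(row[i] << 4) | row[i + 1] for i in range(0, len(row), 2)])
    return packed

# [N=2, K=4] int4 matrix -> packed shape [2, 2]
w = [[1, 2, 3, 4], [5, 6, 7, 8]]
p = pack_int4_rows(w)
assert len(p) == 2 and len(p[0]) == 2
assert p[0][0] == 0x12 and p[1][1] == 0x78
```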

Multiple fixes to Int4Tensor to align it with the design of Float8Tensor (only calling fbgemm ops).

Note: this is just a refactor of Int4Tensor; there are no BC-related changes in this PR.

The Int4Tensor path is exposed in version 2 of Int4WeightOnlyConfig (the default version is still 1, which uses the old AQT path).
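As a rough illustration of the version gate described above (the dataclass and function names here are hypothetical stand-ins, not the real torchao API):

```python
from dataclasses import dataclass

@dataclass
class Int4WeightOnlyConfigSketch:
    # hypothetical mirror of the gate: version 1 -> old AQT path,
    # version 2 -> new Int4Tensor path
    group_size: int = 128
    version: int = 1

def select_path(config):
    if config.version == 1:
        return "aqt"
    elif config.version == 2:
        return "int4_tensor"
    raise ValueError(f"unsupported version: {config.version}")

assert select_path(Int4WeightOnlyConfigSketch()) == "aqt"
assert select_path(Int4WeightOnlyConfigSketch(version=2)) == "int4_tensor"
```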

Test Plan:
python test/quantization/quantize_/workflows/int4/test_int4_tensor.py

Reviewers:

Subscribers:

Tasks:

Tags:


pytorch-bot bot commented Aug 5, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2687

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit ba62d8e with merge base c086ade:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

jerryzh168 added a commit that referenced this pull request Aug 5, 2025
Summary:
Int4Tensor is the non-preshuffled version of int4 quantized Tensor, data is [N, K/2], scale/zero_point has shape: [K/group_size, N]

Multiple fixes for Int4Tensor to align with the design of Float8Tensor (only calling fbgemm ops)
* Added VERSION 2 for Int4WeightOnlyConfig
* Migrated op implementation and tests from #2387

Test Plan:
python test/quantization/quantize_/workflows/int4/test_int4_tensor.py

Reviewers:

Subscribers:

Tasks:

Tags:

stack-info: PR: #2687, branch: jerryzh168/stack/16
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/16 branch from 1d84542 to 4874773 Compare August 5, 2025 03:25
@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Aug 5, 2025
@jerryzh168 jerryzh168 added the topic: not user facing Use this tag if you don't want this PR to show up in release notes label Aug 5, 2025
@jerryzh168 jerryzh168 changed the base branch from jerryzh168/stack/10 to main August 5, 2025 18:39
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/16 branch from 4874773 to 1beccb0 Compare August 5, 2025 18:39
jerryzh168 added a commit that referenced this pull request Aug 5, 2025
@jerryzh168 jerryzh168 changed the base branch from main to jerryzh168/stack/10 August 5, 2025 18:39
res = torch.ops.fbgemm.bf16i4bf16_rowwise(
input_tensor,
weight_tensor._data.contiguous(),
weight_tensor.qdata.contiguous(),
Contributor

Is it expected that the tensors are not contiguous? If not, can we assert this instead of calling contiguous()?

Contributor Author
@jerryzh168 jerryzh168 Aug 5, 2025

The non-contiguity comes from reshape ops like transpose and view, I think, but the kernel needs these tensors to be contiguous. I can try changing these calls to asserts and doing the contiguous conversion on the user side to see if it works.
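For background, contiguity is a property of strides, and ops like transpose swap strides without moving data; a small self-contained sketch of the check (plain Python, mirroring what torch.Tensor.is_contiguous reports):

```python
def row_major_strides(shape):
    # element strides for a C-contiguous (row-major) layout
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides

def is_contiguous(shape, strides):
    return list(strides) == row_major_strides(shape)

shape, strides = (4, 6), row_major_strides((4, 6))   # strides == [6, 1]
assert is_contiguous(shape, strides)

# transpose swaps shape and strides without copying data,
# which is why the result is no longer contiguous
t_shape, t_strides = shape[::-1], strides[::-1]      # (6, 4) with strides [1, 6]
assert not is_contiguous(t_shape, t_strides)
```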

Contributor

I would have expected the weights to be stored in a format aligned with what the kernel needs, without any need for just-in-time layout transforms. Does this match how the current code works?

Contributor Author
@jerryzh168 jerryzh168 Aug 5, 2025

Normally it is, but the weights also go through some transformations, like the ones listed in test_moe_weight_reshape_ops, which make the weight, scale, etc. non-contiguous, I think. I can try calling contiguous in user code instead; that might be cleaner.

Contributor Author

It turns out contiguous was not implemented properly; I just fixed that, and we can remove the contiguous calls in linear/bmm now.

@jerryzh168 jerryzh168 changed the base branch from jerryzh168/stack/10 to main August 5, 2025 23:30
jerryzh168 added a commit that referenced this pull request Aug 5, 2025
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/16 branch from 1beccb0 to 5f6306e Compare August 5, 2025 23:30
@jerryzh168 jerryzh168 changed the base branch from main to jerryzh168/stack/10 August 5, 2025 23:30


@register_quantize_module_handler(TestOnlyMoEQuantConfig)
def moe_quant_fn(module, config: TestOnlyMoEQuantConfig):
Contributor

This is really confusing; could you share the result of print(model) after this function has been applied?

If it's going to print the model with parameters wrapped in Int4Tensor, can we just wrap the parameters directly without all of these layers of abstraction?

Contributor

If this is working around the fact that quantize_ needs to work on modules, IMO we should change quantize_ to handle this instead of working around it. This seems important for MoEs.

Contributor Author

Yeah, the parameters are wrapped in Int4Tensor; this just applies quantization to each of the MoE weights: w1, w2 and w3.

I can inline these for now and follow up on an API that takes weights and configs separately.
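A tiny sketch of that idea, wrapping each expert weight in place (every name here is hypothetical, standing in for Int4Tensor and the real expert module):

```python
class QuantizedWeight:
    """Hypothetical stand-in for a quantized tensor wrapper like Int4Tensor."""
    def __init__(self, raw):
        self.raw = raw

class ExpertsSketch:
    """Hypothetical stand-in for an MoE experts module with weights w1, w2, w3."""
    def __init__(self, w1, w2, w3):
        self.w1, self.w2, self.w3 = w1, w2, w3

def quantize_moe_weights(module):
    # replace each expert weight with its quantized wrapper, in place
    for name in ("w1", "w2", "w3"):
        setattr(module, name, QuantizedWeight(getattr(module, name)))
    return module

m = quantize_moe_weights(ExpertsSketch([1.0], [2.0], [3.0]))
assert isinstance(m.w1, QuantizedWeight) and m.w3.raw == [3.0]
```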

Contributor Author

Probably not worth changing the API right now since MoE quant is also in flux; let me know if the current code looks good.

@@ -177,3 +178,63 @@ def create_model_and_input_data(
else:
raise ValueError(f"Unknown model type: {model_type}")
return model, input_data


class Experts(nn.Module):
Contributor

Maybe call it something like FeedForwardWithExperts? Experts is ambiguous.

Contributor Author

This is adapted from https://github.com/meta-llama/llama-models/blob/a9c89c471f793423afd4cc3ca8671d6e56fe64cb/models/llama4/moe.py#L22; how about renaming it to Llama4Experts to make it more specific?

@jerryzh168 jerryzh168 changed the base branch from jerryzh168/stack/10 to main August 6, 2025 01:07
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/16 branch from 5f6306e to 6bd3106 Compare August 6, 2025 01:08
jerryzh168 added a commit that referenced this pull request Aug 6, 2025
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/16 branch 2 times, most recently from 65cdd3e to 28bd29c Compare August 8, 2025 04:19
jerryzh168 added a commit to jerryzh168/ao that referenced this pull request Aug 8, 2025
Summary:
We have recently updated our design for structuring tensor subclasses in torchao
to remove unnecessary abstractions and reduce indirections and having a structuring that
aligns better with people's intuitive understanding of different quantization use cases,
examples using the new design are: pytorch#2463, pytorch#2687

Test Plan:
check generated doc
Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 added a commit to jerryzh168/ao that referenced this pull request Aug 8, 2025
jerryzh168 added a commit to jerryzh168/ao that referenced this pull request Aug 8, 2025
@jerryzh168 jerryzh168 requested a review from drisspg August 8, 2025 23:59
tensor_data_attrs = ["_data", "scale", "zero_point"]
tensor_attributes = ["block_size", "shape"]
tensor_data_names = ["qdata", "scale", "zero_point"]
tensor_attribute_names = ["block_size", "shape"]
Contributor

tensor_attribute_names is a little weird to me.
Are these attributes that are tensors and thus should go in the right unflatten location?

Contributor Author
@jerryzh168 jerryzh168 Aug 11, 2025

Yeah, this somewhat follows how the unflatten and flatten functions name these things: "tensor" means the tensor subclass instance, so these are the attributes of the tensor subclass instance, not "attributes that are tensors".

I could remove the tensor_ prefix to make it less confusing, but that's probably better done in a separate PR.
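A minimal sketch of how class-level name lists can drive generic flatten/unflatten (plain Python; the real TorchAOBaseTensor hooks into torch's tensor-subclass flatten protocol and differs in detail):

```python
class BaseTensorLike:
    tensor_data_names = ()        # attributes holding tensor data (flattened out)
    tensor_attribute_names = ()   # plain metadata attributes (kept as context)

    def flatten(self):
        data = {name: getattr(self, name) for name in self.tensor_data_names}
        ctx = {name: getattr(self, name) for name in self.tensor_attribute_names}
        return data, ctx

    @classmethod
    def unflatten(cls, data, ctx):
        obj = cls.__new__(cls)
        for name, value in {**data, **ctx}.items():
            setattr(obj, name, value)
        return obj

class Int4TensorSketch(BaseTensorLike):
    tensor_data_names = ("qdata", "scale", "zero_point")
    tensor_attribute_names = ("block_size", "shape")

t = Int4TensorSketch.unflatten(
    {"qdata": [[0x12]], "scale": [1.0], "zero_point": [0]},
    {"block_size": 128, "shape": (1, 2)},
)
data, ctx = t.flatten()
assert set(data) == {"qdata", "scale", "zero_point"}
assert ctx == {"block_size": 128, "shape": (1, 2)}
```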



@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
class TestInt4Tensor(TestCase):
class TestInt4Tensor(TorchAOIntegrationTestCase):
Contributor

Are we keeping both old and new? If we are keeping the old version working, I would expect this test case to have no changes, since it would test the old version.

Contributor Author

Old version meaning the version using AQT? We are keeping AQT, but this test does not test the AQT path; it only tests the new Int4Tensor, which is what we are updating in this PR.

Contributor

Can we update the PR summary with context on this? Migrations are always confusing, and clearly laying out what is changing with BC, what is breaking BC, and what is not changing will help get a good review.

Contributor Author

OK, added: there are no BC-related changes in this PR.

jerryzh168 added a commit that referenced this pull request Aug 11, 2025
…the Float8Tensor

Summary:
similar to #2687, we updated Int4PreshuffledTensor to align
the implementation details, also used TorchAOBaseTensor to simplify some of the implementations

Test Plan:
python test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py

Reviewers:

Subscribers:

Tasks:

Tags:
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/16 branch from 28bd29c to 820f264 Compare August 11, 2025 20:36
jerryzh168 added a commit that referenced this pull request Aug 11, 2025
jerryzh168 added a commit to jerryzh168/ao that referenced this pull request Aug 11, 2025
jerryzh168 added a commit to jerryzh168/ao that referenced this pull request Aug 11, 2025
jerryzh168 added a commit to jerryzh168/ao that referenced this pull request Aug 11, 2025
jerryzh168 added a commit that referenced this pull request Aug 12, 2025
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/16 branch from 820f264 to 640ba3b Compare August 12, 2025 02:49
Summary:
Int4Tensor is the non-preshuffled version of int4 quantized Tensor, data is [N, K/2], scale/zero_point has shape: [K/group_size, N]

Multiple fixes for Int4Tensor to align with the design of Float8Tensor (only calling fbgemm ops)
* defined `tensor_data_names` and `tensor_attribute_names` so we can remove some of the implementations from TorchAOBaseTensor
* Migrated op implementation and tests from #2387

Note: This is just refactoring Int4Tensor, no BC related changes in this PR

Int4Tensor path is exposed in version 2 of `Int4WeightOnlyConfig` (default version is still 1, which uses the old AQT path)

Test Plan:
python test/quantization/quantize_/workflows/int4/test_int4_tensor.py

Reviewers:

Subscribers:

Tasks:

Tags:

stack-info: PR: #2687, branch: jerryzh168/stack/16
jerryzh168 added a commit that referenced this pull request Aug 12, 2025
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/16 branch from 640ba3b to ba62d8e Compare August 12, 2025 02:52
jerryzh168 added a commit to jerryzh168/ao that referenced this pull request Aug 12, 2025
@jerryzh168 jerryzh168 merged commit 0b88286 into main Aug 12, 2025
18 checks passed
jerryzh168 added a commit that referenced this pull request Aug 12, 2025
jerryzh168 added a commit that referenced this pull request Aug 12, 2025
jerryzh168 added a commit that referenced this pull request Aug 12, 2025
…the Float8Tensor (#2738)
jerryzh168 added a commit to jerryzh168/ao that referenced this pull request Aug 12, 2025